{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c0a3331-48f3-424e-87af-1ab230025abd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.metrics import f1_score\n",
    "#1000条数据试验词向量转换功能，转换为FastText需要的格式\n",
    "train_df = pd.read_csv('./train_set.csv', sep='\\t', nrows=1000)\n",
    "train_df['label_ft'] = '__label__' + train_df['label'].astype(str)\n",
    "train_df[['text','label_ft']].iloc[:-500].to_csv('train1000.csv', index=None, header=None, sep='\\t')\n",
    "\n",
    "import fasttext\n",
    "model1 = fasttext.train_unsupervised('train1000.csv', lr=0.1, wordNgrams=2, \n",
    "                                  verbose=2, minCount=1, epoch=2, loss=\"hs\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "571e3392-d458-416a-9af0-fc94a24f7da6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model1.save_model(\"result1000.bin\")\n",
    "#cmd命令行执行python bin_to_vec.py result1000.bin < result1000.vec，转换为vec词向量\n",
    "\n",
    "model2 = fasttext.train_supervised('train1000.csv',pretrainedVectors='result1000.vec',lr=1.0, wordNgrams=2, \n",
    "                                  verbose=2, minCount=1, epoch=2, loss=\"hs\")\n",
    "#试验成功"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ef33be3-b263-4e2a-91e4-a0443b9b407b",
   "metadata": {},
   "source": [
    "正式测试开始，选取15000条数据。前1w条训练，后5000条测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f3b810f9-685f-4f1c-9f91-9eef0711ca86",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8271753451328839\n"
     ]
    }
   ],
   "source": [
    "#正常fasttext训练\n",
    "import pandas as pd\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# 转换为FastText需要的格式\n",
    "train_df = pd.read_csv('./train_set.csv', sep='\\t', nrows=15000)\n",
    "train_df['label_ft'] = '__label__' + train_df['label'].astype(str)\n",
    "train_df[['text','label_ft']].iloc[:-5000].to_csv('train10000.csv', index=None, header=None, sep='\\t')\n",
    "\n",
    "import fasttext\n",
    "model = fasttext.train_supervised('train.csv', lr=1.0, wordNgrams=2, \n",
    "                                  verbose=2, minCount=1, epoch=25, loss=\"hs\")\n",
    "\n",
    "val_pred = [model.predict(x)[0][0].split('__')[-1] for x in train_df.iloc[-5000:]['text']]\n",
    "print(f1_score(train_df['label'].values[-5000:].astype(str), val_pred, average='macro'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4ba880a1-fa81-4bbb-8d51-40231746e33b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#先进行word2vec训练，使用全部15000数据\n",
    "train_df[['text','label_ft']].to_csv('train15000.csv', index=None, header=None, sep='\\t')\n",
    "model1 = fasttext.train_unsupervised('train15000.csv', lr=0.1, wordNgrams=2, \n",
    "                                  verbose=2, minCount=1, epoch=8, loss=\"hs\")\n",
    "#保存模型转为词向量\n",
    "model1.save_model(\"word15000.bin\")\n",
    "#cmd命令行执行python bin_to_vec.py word15000.bin < word15000.vec，转换为vec词向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "77a1ed1a-3220-49e4-b38d-682a78609a82",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8502375081398378\n"
     ]
    }
   ],
   "source": [
    "#fasttext进行1w条数据训练\n",
    "\"\"\"import fasttext\n",
    "import pandas as pd\n",
    "from sklearn.metrics import f1_score\n",
    "model2 = fasttext.train_supervised('train.csv',pretrainedVectors='word15000.vec',lr=1.0, wordNgrams=2, \n",
    "                                 verbose=2, minCount=1, epoch=16, loss=\"hs\")\"\"\"\n",
    "#预测结果\n",
    "train_df = pd.read_csv('./train_set.csv', sep='\\t', nrows=15000)\n",
    "val_pred = [model2.predict(x)[0][0].split('__')[-1] for x in train_df.iloc[-5000:]['text']]\n",
    "print(f1_score(train_df['label'].values[-5000:].astype(str), val_pred, average='macro'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5c435fa-e68b-46b3-b279-0a3ccb65db0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#首尾截断实验效果\n",
    "#准备将text文本首尾截断，各取255tokens\n",
    "def slipt2(x):\n",
    "  ls=x.split(' ')\n",
    "  le=len(ls)\n",
    "  if le<201:\n",
    "    return x\n",
    "  else:\n",
    "    return ' '.join(ls[:100]+ls[-100:])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "0eb385a3-ae69-443c-ae16-a8530a647690",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8304409751791343\n"
     ]
    }
   ],
   "source": [
    "#首尾截断进行训练\n",
    "train_df['summary']=train_df['text'].apply(lambda x:slipt2(x))\n",
    "train_df[['summary','label_ft']].iloc[:-5000].to_csv('train_summary10000.csv', index=None, header=None, sep='\\t')\n",
    "\n",
    "model3 = fasttext.train_supervised('train_summary10000.csv',pretrainedVectors='word15000.vec',lr=1.0, wordNgrams=2, \n",
    "                                  verbose=2, minCount=1, epoch=16, loss=\"hs\")\n",
    "#预测结果\n",
    "val_pred = [model3.predict(x)[0][0].split('__')[-1] for x in train_df.iloc[-5000:]['summary']]\n",
    "print(f1_score(train_df['label'].values[-5000:].astype(str), val_pred, average='macro'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "046154ba-04ad-4977-80b1-4ff25e18a6f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "model2.save_model(\"fastword15000.bin\")\n",
    "#cmd命令行执行python bin_to_vec.py fastword15000.bin < fastword15000.vec，转换为vec词向量\n",
    "#对比fasttext_word15000.vec和word15000.vec词向量，两者都是（5516 100），但是词向量已经不一样了。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6fff999-7978-4289-8fb0-5a691621a6c0",
   "metadata": {},
   "source": [
    "------------------------------------------分割线啊---------------------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16fb76e2-f792-4ba4-b967-4012eab0b1cb",
   "metadata": {},
   "source": [
    "开始正式训练\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8e96603-b38a-4daf-815b-0dd317656566",
   "metadata": {},
   "outputs": [],
   "source": [
    "#读取训练测试集数据\n",
    "import pandas as pd\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# 转换为FastText需要的格式\n",
    "train_df = pd.read_csv('./train_set.csv', sep='\\t')\n",
    "train_df['label_ft'] = '__label__' + train_df['label'].astype(str)\n",
    "train_df[['text','label_ft']].to_csv('train_20w.csv', index=None, header=None, sep='\\t')\n",
    "\n",
    "test_df = pd.read_csv('./test_a.csv', sep='\\t')\n",
    "df=pd.concat([train_df,test_df])\n",
    "df[['text']].to_csv('train_25w.csv', index=None, header=None, sep='\\t')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b27d8b42-65ea-481f-aa3d-5e9961245e25",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "96096586-abd7-4e73-969a-c43adae7f94f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import fasttext\n",
    "#word2vec进行train+test词向量训练\n",
    "model1 = fasttext.train_unsupervised('train_25w.csv', lr=0.1, wordNgrams=2, \n",
    "                                  verbose=2, minCount=1, epoch=8, loss=\"hs\")\n",
    "\n",
    "model1.save_model(\"word_25w.bin\")\n",
    "#python bin_to_vec.py word_25w.bin > word_25w.vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8410fe11-ae9d-4947-94b8-ff2d0c3b44c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#fasttext进行有监督训练\n",
    "#model2 = fasttext.train_supervised('train_20w.csv',pretrainedVectors='word_25w.vec',lr=0.8, wordNgrams=2, \n",
    "#                                 verbose=2, minCount=1, epoch=18, loss=\"hs\")\n",
    "import pandas as pd\n",
    "test_df = pd.read_csv('./test_a.csv', sep='\\t')\n",
    "test_pred = [model2.predict(x)[0][0].split('__')[-1] for x in test_df['text']]\n",
    "\n",
    "pd.DataFrame({'label':test_pred}).to_csv('word_fast.csv',index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "56ff8d61-c32a-4fab-8269-341aa3fde874",
   "metadata": {},
   "outputs": [],
   "source": [
    "#首尾截断实验效果\n",
    "#准备将text文本首尾截断，各取100tokens\n",
    "def slipt2(x):\n",
    "    ls=x.split(' ')\n",
    "    le=len(ls)\n",
    "    if le<301:\n",
    "        return x\n",
    "    else:\n",
    "        return ' '.join(ls[:150]+ls[-150:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "bdca6a7e-b275-43d5-a62c-a47672e102bb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>text</th>\n",
       "      <th>label_ft</th>\n",
       "      <th>summary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>2967 6758 339 2021 1854 3731 4109 3792 4149 15...</td>\n",
       "      <td>__label__2</td>\n",
       "      <td>2967 6758 339 2021 1854 3731 4109 3792 4149 15...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>11</td>\n",
       "      <td>4464 486 6352 5619 2465 4802 1452 3137 5778 54...</td>\n",
       "      <td>__label__11</td>\n",
       "      <td>4464 486 6352 5619 2465 4802 1452 3137 5778 54...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>7346 4068 5074 3747 5681 6093 1777 2226 7354 6...</td>\n",
       "      <td>__label__3</td>\n",
       "      <td>7346 4068 5074 3747 5681 6093 1777 2226 7354 6...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>7159 948 4866 2109 5520 2490 211 3956 5520 549...</td>\n",
       "      <td>__label__2</td>\n",
       "      <td>7159 948 4866 2109 5520 2490 211 3956 5520 549...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3</td>\n",
       "      <td>3646 3055 3055 2490 4659 6065 3370 5814 2465 5...</td>\n",
       "      <td>__label__3</td>\n",
       "      <td>3646 3055 3055 2490 4659 6065 3370 5814 2465 5...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>199995</th>\n",
       "      <td>2</td>\n",
       "      <td>307 4894 7539 4853 5330 648 6038 4409 3764 603...</td>\n",
       "      <td>__label__2</td>\n",
       "      <td>307 4894 7539 4853 5330 648 6038 4409 3764 603...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>199996</th>\n",
       "      <td>2</td>\n",
       "      <td>3792 2983 355 1070 4464 5050 6298 3782 3130 68...</td>\n",
       "      <td>__label__2</td>\n",
       "      <td>3792 2983 355 1070 4464 5050 6298 3782 3130 68...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>199997</th>\n",
       "      <td>11</td>\n",
       "      <td>6811 1580 7539 1252 1899 5139 1386 3870 4124 1...</td>\n",
       "      <td>__label__11</td>\n",
       "      <td>6811 1580 7539 1252 1899 5139 1386 3870 4124 1...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>199998</th>\n",
       "      <td>2</td>\n",
       "      <td>6405 3203 6644 983 794 1913 1678 5736 1397 191...</td>\n",
       "      <td>__label__2</td>\n",
       "      <td>6405 3203 6644 983 794 1913 1678 5736 1397 191...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>199999</th>\n",
       "      <td>3</td>\n",
       "      <td>4350 3878 3268 1699 6909 5505 2376 2465 6088 2...</td>\n",
       "      <td>__label__3</td>\n",
       "      <td>4350 3878 3268 1699 6909 5505 2376 2465 6088 2...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>200000 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        label                                               text     label_ft  \\\n",
       "0           2  2967 6758 339 2021 1854 3731 4109 3792 4149 15...   __label__2   \n",
       "1          11  4464 486 6352 5619 2465 4802 1452 3137 5778 54...  __label__11   \n",
       "2           3  7346 4068 5074 3747 5681 6093 1777 2226 7354 6...   __label__3   \n",
       "3           2  7159 948 4866 2109 5520 2490 211 3956 5520 549...   __label__2   \n",
       "4           3  3646 3055 3055 2490 4659 6065 3370 5814 2465 5...   __label__3   \n",
       "...       ...                                                ...          ...   \n",
       "199995      2  307 4894 7539 4853 5330 648 6038 4409 3764 603...   __label__2   \n",
       "199996      2  3792 2983 355 1070 4464 5050 6298 3782 3130 68...   __label__2   \n",
       "199997     11  6811 1580 7539 1252 1899 5139 1386 3870 4124 1...  __label__11   \n",
       "199998      2  6405 3203 6644 983 794 1913 1678 5736 1397 191...   __label__2   \n",
       "199999      3  4350 3878 3268 1699 6909 5505 2376 2465 6088 2...   __label__3   \n",
       "\n",
       "                                                  summary  \n",
       "0       2967 6758 339 2021 1854 3731 4109 3792 4149 15...  \n",
       "1       4464 486 6352 5619 2465 4802 1452 3137 5778 54...  \n",
       "2       7346 4068 5074 3747 5681 6093 1777 2226 7354 6...  \n",
       "3       7159 948 4866 2109 5520 2490 211 3956 5520 549...  \n",
       "4       3646 3055 3055 2490 4659 6065 3370 5814 2465 5...  \n",
       "...                                                   ...  \n",
       "199995  307 4894 7539 4853 5330 648 6038 4409 3764 603...  \n",
       "199996  3792 2983 355 1070 4464 5050 6298 3782 3130 68...  \n",
       "199997  6811 1580 7539 1252 1899 5139 1386 3870 4124 1...  \n",
       "199998  6405 3203 6644 983 794 1913 1678 5736 1397 191...  \n",
       "199999  4350 3878 3268 1699 6909 5505 2376 2465 6088 2...  \n",
       "\n",
       "[200000 rows x 4 columns]"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "7c50cf8a-57dd-4a67-9168-4e2fdd8821f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>summary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5399 3117 1070 4321 4568 2621 5466 3772 4516 2...</td>\n",
       "      <td>5399 3117 1070 4321 4568 2621 5466 3772 4516 2...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2491 4109 1757 7539 648 3695 3038 4490 23 7019...</td>\n",
       "      <td>2491 4109 1757 7539 648 3695 3038 4490 23 7019...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2673 5076 6835 2835 5948 5677 3247 4124 2465 5...</td>\n",
       "      <td>2673 5076 6835 2835 5948 5677 3247 4124 2465 5...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4562 4893 2210 4761 3659 1324 2595 5949 4583 2...</td>\n",
       "      <td>4562 4893 2210 4761 3659 1324 2595 5949 4583 2...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4269 7134 2614 1724 4464 1324 3370 3370 2106 2...</td>\n",
       "      <td>4269 7134 2614 1724 4464 1324 3370 3370 2106 2...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49995</th>\n",
       "      <td>3725 4498 2282 1647 6293 4245 4498 3615 1141 2...</td>\n",
       "      <td>3725 4498 2282 1647 6293 4245 4498 3615 1141 2...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49996</th>\n",
       "      <td>4811 465 3800 1394 3038 2376 2327 5165 3070 57...</td>\n",
       "      <td>4811 465 3800 1394 3038 2376 2327 5165 3070 57...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49997</th>\n",
       "      <td>5338 1952 3117 4109 299 6656 6654 3792 6831 21...</td>\n",
       "      <td>5338 1952 3117 4109 299 6656 6654 3792 6831 21...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49998</th>\n",
       "      <td>893 3469 5775 584 2490 4223 6569 6663 2124 168...</td>\n",
       "      <td>893 3469 5775 584 2490 4223 6569 6663 2124 168...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49999</th>\n",
       "      <td>2400 4409 4412 2210 5122 4464 7186 2465 1327 9...</td>\n",
       "      <td>2400 4409 4412 2210 5122 4464 7186 2465 1327 9...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>50000 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    text  \\\n",
       "0      5399 3117 1070 4321 4568 2621 5466 3772 4516 2...   \n",
       "1      2491 4109 1757 7539 648 3695 3038 4490 23 7019...   \n",
       "2      2673 5076 6835 2835 5948 5677 3247 4124 2465 5...   \n",
       "3      4562 4893 2210 4761 3659 1324 2595 5949 4583 2...   \n",
       "4      4269 7134 2614 1724 4464 1324 3370 3370 2106 2...   \n",
       "...                                                  ...   \n",
       "49995  3725 4498 2282 1647 6293 4245 4498 3615 1141 2...   \n",
       "49996  4811 465 3800 1394 3038 2376 2327 5165 3070 57...   \n",
       "49997  5338 1952 3117 4109 299 6656 6654 3792 6831 21...   \n",
       "49998  893 3469 5775 584 2490 4223 6569 6663 2124 168...   \n",
       "49999  2400 4409 4412 2210 5122 4464 7186 2465 1327 9...   \n",
       "\n",
       "                                                 summary  \n",
       "0      5399 3117 1070 4321 4568 2621 5466 3772 4516 2...  \n",
       "1      2491 4109 1757 7539 648 3695 3038 4490 23 7019...  \n",
       "2      2673 5076 6835 2835 5948 5677 3247 4124 2465 5...  \n",
       "3      4562 4893 2210 4761 3659 1324 2595 5949 4583 2...  \n",
       "4      4269 7134 2614 1724 4464 1324 3370 3370 2106 2...  \n",
       "...                                                  ...  \n",
       "49995  3725 4498 2282 1647 6293 4245 4498 3615 1141 2...  \n",
       "49996  4811 465 3800 1394 3038 2376 2327 5165 3070 57...  \n",
       "49997  5338 1952 3117 4109 299 6656 6654 3792 6831 21...  \n",
       "49998  893 3469 5775 584 2490 4223 6569 6663 2124 168...  \n",
       "49999  2400 4409 4412 2210 5122 4464 7186 2465 1327 9...  \n",
       "\n",
       "[50000 rows x 2 columns]"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "d0e5cd7e-9292-4410-9b01-5ca40d75c478",
   "metadata": {},
   "outputs": [],
   "source": [
    "#首尾截断进行训练\n",
    "#train_df = pd.read_csv('./train_set.csv', sep='\\t')\n",
    "#train_df['label_ft'] = '__label__' + train_df['label'].astype(str)\n",
    "\n",
    "train_df['summary']=train_df['text'].apply(lambda x:slipt2(x))\n",
    "\n",
    "train_df[['summary','label_ft']].to_csv('train_summary_20w.csv', index=None, header=None, sep='\\t')\n",
    "\n",
    "model3 = fasttext.train_supervised('train_summary_20w.csv',pretrainedVectors='word_25w.vec',lr=0.8, wordNgrams=3, \n",
    "                                  verbose=2, minCount=1, epoch=18, loss=\"softmax\")\n",
    "\n",
    "val_pred = [model3.predict(x)[0][0].split('__')[-1] for x in train_df.iloc[-10000:]['summary']]\n",
    "\n",
    "from sklearn.metrics import f1_score\n",
    "print(f1_score(train_df['label'].values[-10000:].astype(str), val_pred, average='macro'))\n",
    "\n",
    "#预测结果\n",
    "test_df['summary']=test_df['text'].apply(lambda x:slipt2(x))\n",
    "test_pred = [model3.predict(x)[0][0].split('__')[-1] for x in test_df['summary']]\n",
    "\n",
    "pd.DataFrame({'label':test_pred}).to_csv('word_fast_cut_best.csv',index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "101820f7-bab8-4402-8bfd-907694c03ccf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9157807172377387\n"
     ]
    }
   ],
   "source": [
    "#看看不截断的f1分数\n",
    "\"\"\"train_df[['text','label_ft']].iloc[:-10000].to_csv('train_test_19w.csv', index=None, header=None, sep='\\t')\n",
    "\n",
    "model3 = fasttext.train_supervised('train_test_19w.csv',pretrainedVectors='word_25w.vec',lr=0.8, wordNgrams=2, \n",
    "                                  verbose=2, minCount=1, epoch=18, loss=\"hs\")\"\"\"\n",
    "val_pred = [model3.predict(x)[0][0].split('__')[-1] for x in train_df.iloc[-10000:]['text']]\n",
    "\n",
    "from sklearn.metrics import f1_score\n",
    "print(f1_score(train_df['label'].values[-10000:].astype(str), val_pred, average='macro'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f9f05c10-a9f1-4120-9763-9c24c31f873e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#正常fasttext应用以上参数：\n",
    "import pandas as pd\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# 转换为FastText需要的格式\n",
    "train_df = pd.read_csv('./train_set.csv', sep='\\t')\n",
    "train_df['label_ft'] = '__label__' + train_df['label'].astype(str)\n",
    "train_df[['text','label_ft']].to_csv('train_20w.csv', index=None, header=None, sep='\\t')\n",
    "\n",
    "import fasttext\n",
    "model = fasttext.train_supervised('train_20w.csv', lr=0.8, wordNgrams=3, \n",
    "                                  verbose=2, minCount=1, epoch=25, loss=\"softmax\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ad421e7e-0989-4efc-b0fb-e51910318b69",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_df = pd.read_csv('./test_a.csv', sep='\\t')\n",
    "test_pred = [model.predict(x)[0][0].split('__')[-1] for x in test_df['text']]\n",
    "\n",
    "pd.DataFrame({'label':test_pred}).to_csv('fast2.csv',index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e99d755a-cc37-4c0c-a014-716a81b46650",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
